from matplotlib import pyplot as plt
import numpy as np
import pickle

# Load the pickled training history. To plot the alternative run instead,
# point the path below at "./Model2/AwA2/alex_data.bin".
# NOTE: pickle.load can execute arbitrary code; only load files written by
# your own training runs, never untrusted input.
with open("./Model2/AwA2/resnet_data_10e3.bin", "rb") as f:
    data = pickle.load(f)

# `data` is a sequence of 4-item records, one per plotted epoch:
# (train_loss, train_acc, test_loss, test_acc). Transpose into four series.
train_loss, train_acc, test_loss, test_acc = zip(*data)
train_loss = np.array(train_loss)
# Accuracies are multiplied by 100 for display as percentages (presumably
# stored as fractions in [0, 1] -- TODO confirm against the training script).
train_acc = np.array(train_acc) * 100
test_loss = np.array(test_loss)
test_acc = np.array(test_acc) * 100

# Plot train vs. test accuracy (%) per epoch.
# The x-axis is derived from the data length rather than hard-coded, so
# histories of any number of epochs plot without a shape-mismatch error.
epochs = np.arange(1, len(train_acc) + 1)

fig = plt.figure(figsize=(8, 6))
plt.plot(epochs, train_acc, linewidth=2, label="Train Acc")
plt.plot(epochs, test_acc, linewidth=2, label="Test Acc")
plt.legend(fontsize=15, loc="lower right")
plt.xticks(fontsize=12)
plt.yticks(fontsize=12)
plt.ylabel('Acc(%)', fontsize=15)
plt.xlabel('Epoch', fontsize=15)
plt.tight_layout()
plt.savefig("./Figure/test.png")
# Free the figure before the script goes on to build the next plot.
plt.close(fig)

# Plot train vs. test loss per epoch.
# The x-axis is derived from the data length rather than hard-coded, so
# histories of any number of epochs plot without a shape-mismatch error.
epochs = np.arange(1, len(train_loss) + 1)

fig = plt.figure(figsize=(8, 6))
plt.plot(epochs, train_loss, linewidth=2, label="Train Loss")
plt.plot(epochs, test_loss, linewidth=2, label="Test Loss")
# "best" lets matplotlib pick the corner with the least overlap instead of
# assuming where the loss curves will end up.
plt.legend(fontsize=15, loc="best")
plt.xticks(fontsize=12)
plt.yticks(fontsize=12)
plt.ylabel('Loss', fontsize=15)
plt.xlabel('Epoch', fontsize=15)
plt.tight_layout()
plt.savefig("./Figure/test_loss.png")
plt.close(fig)